use bitflags::bitflags;

use crate::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};

/// Per-type information gathered for offloading a value.
///
/// Built by `OffloadMetadata::from_ty`, which fills both fields from the type
/// alone.
pub struct OffloadMetadata {
    /// Size in bytes as reported by `get_payload_size`: the layout size of the
    /// type reached after stripping every raw-pointer / reference layer.
    pub payload_size: u64,
    /// Mapping flags chosen for the type by `MappingFlags::from_ty`.
    pub mode: MappingFlags,
}

bitflags! {
    /// Rust-side copy of the `OpenMPOffloadMappingFlags` bit set from Clang/OpenMP.
    ///
    /// Single-bit flags are spelled as shifts so the bit position is explicit;
    /// the numeric values are unchanged.
    #[derive(Debug, Copy, Clone)]
    #[repr(transparent)]
    pub struct MappingFlags: u64 {
        /// Empty flag set.
        const NONE           = 0;
        /// Allocate device memory and copy the data host -> device.
        const TO             = 1 << 0;
        /// Allocate device memory and copy the data device -> host.
        const FROM           = 1 << 1;
        /// Perform the requested mapping action even if the data is already mapped.
        const ALWAYS         = 1 << 2;
        /// Remove the element from the device environment, ignoring its ref count.
        const DELETE         = 1 << 3;
        /// The mapped element is a pointer-pointee pair.
        const PTR_AND_OBJ    = 1 << 4;
        /// Pass the base address to the target kernel as an argument.
        const TARGET_PARAM   = 1 << 5;
        /// The runtime has to hand back the device pointer.
        const RETURN_PARAM   = 1 << 6;
        /// The passed reference points to private data.
        const PRIVATE        = 1 << 7;
        /// The element is passed by value.
        const LITERAL        = 1 << 8;
        /// Compiler-generated (implicit) map, not written explicitly in the code.
        const IMPLICIT       = 1 << 9;
        /// Hint: allocate the memory close to the target device.
        const CLOSE          = 1 << 10;
        /// Reserved bit (0x800 in OpenMP, kept for XLC compatibility).
        const RESERVED       = 1 << 11;
        /// The data must already be allocated on the device.
        const PRESENT        = 1 << 12;
        /// Use a separate ref counter for increments/decrements (OpenACC compatibility).
        const OMPX_HOLD      = 1 << 13;
        /// Marks non-contiguous list items in a target update.
        const NON_CONTIG     = 1 << 44;
        /// The 16 most significant bits encode membership in a struct.
        const MEMBER_OF      = 0xffff << 48;
    }
}

impl OffloadMetadata {
    /// Collects the offload metadata for `ty`.
    ///
    /// The payload size is computed first via `get_payload_size`, then the
    /// mapping flags via `MappingFlags::from_ty`; the latter reports an error
    /// through `tcx` for types it rejects.
    pub fn from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
        let payload_size = get_payload_size(tcx, ty);
        let mode = MappingFlags::from_ty(tcx, ty);
        Self { payload_size, mode }
    }
}

// FIXME(Sa4dUs): implement a solid logic to determine the payload size
/// Returns the layout size in bytes of the type that `ty` ultimately points to.
///
/// Every raw-pointer and reference layer is stripped first; the remaining type
/// is then laid out in a fully monomorphized typing environment. The layout
/// query result is unwrapped, so a layout error panics here.
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
    // Walk through `*T` / `&T` until a non-pointer type is reached.
    let mut pointee = ty;
    while let ty::RawPtr(inner, _) | ty::Ref(_, inner, _) = pointee.kind() {
        pointee = *inner;
    }

    let query = PseudoCanonicalInput {
        typing_env: TypingEnv::fully_monomorphized(),
        value: pointee,
    };
    let layout = tcx.layout_of(query).unwrap();
    layout.size.bytes()
}

impl MappingFlags {
    fn from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
        use rustc_ast::Mutability::*;

        match ty.kind() {
            ty::Bool
            | ty::Char
            | ty::Int(_)
            | ty::Uint(_)
            | ty::Float(_)
            | ty::Adt(_, _)
            | ty::Tuple(_)
            | ty::Array(_, _)
            | ty::Alias(_, _)
            | ty::Param(_) => MappingFlags::TO,

            ty::RawPtr(_, Not) | ty::Ref(_, _, Not) => MappingFlags::TO,

            ty::RawPtr(_, Mut) | ty::Ref(_, _, Mut) => MappingFlags::TO | MappingFlags::FROM,

            ty::Slice(_) | ty::Str | ty::Dynamic(_, _) => MappingFlags::TO | MappingFlags::FROM,

            ty::Foreign(_) | ty::Pat(_, _) | ty::UnsafeBinder(_) => {
                MappingFlags::TO | MappingFlags::FROM
            }

            ty::FnDef(_, _)
            | ty::FnPtr(_, _)
            | ty::Closure(_, _)
            | ty::CoroutineClosure(_, _)
            | ty::Coroutine(_, _)
            | ty::CoroutineWitness(_, _)
            | ty::Never
            | ty::Bound(_, _)
            | ty::Placeholder(_)
            | ty::Infer(_)
            | ty::Error(_) => {
                tcx.dcx()
                    .span_err(rustc_span::DUMMY_SP, format!("type `{ty:?}` cannot be offloaded"));
                MappingFlags::empty()
            }
        }
    }
}
